{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Speed Up and Quantize Any Diffusion Model"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/diffusion_quantization_acceleration.ipynb\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "This tutorial demonstrates how to use the ``pruna`` package to reduce both the latency and the memory footprint of a diffusion model from the ``diffusers`` package.\n",
    "We will use the ``FLUX.1-schnell`` model as an example, but the tutorial works with any Stable Diffusion or Flux model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you are not running the latest version of this tutorial, make sure to install the matching version of pruna\n",
    "# the following command will install the latest version of pruna\n",
    "%pip install pruna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Loading the Diffusion Model\n",
    "\n",
    "First, load your diffusion model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import FluxPipeline\n",
    "\n",
    "pipe = FluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16).to(\"cuda\")"
   ]
  },
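  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the memory savings measurable later, it can help to record how much GPU memory the unoptimized pipeline occupies. The cell below is a minimal sketch that assumes a CUDA device; `baseline_gib` is a name introduced here for illustration and is not part of `pruna` or `diffusers`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Memory currently held by tensors on the default CUDA device, in GiB\n",
    "# (baseline_gib is an illustrative name, not a pruna or diffusers API)\n",
    "baseline_gib = torch.cuda.memory_allocated() / 1024**3\n",
    "print(f\"GPU memory allocated after loading: {baseline_gib:.2f} GiB\")"
   ]
  },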
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Initializing the Smash Config"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "Next, initialize the ``smash_config``. Here we use the :doc:`hqq-diffusers </compression>` and :doc:`torch-compile </compression>` algorithms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import SmashConfig\n",
    "\n",
    "smash_config = SmashConfig([\"hqq_diffusers\", \"torch_compile\"])\n",
    "# smash_config.add({'torch_compile_mode': 'max-autotune'}) # Uncomment to enable extra speedups"
   ]
  },
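  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quantizer hyperparameters can presumably be set the same way as `torch_compile_mode` above. The sketch below assumes the keys `hqq_diffusers_weight_bits` and `hqq_diffusers_group_size`, following that naming pattern; check the pruna documentation for the exact names and allowed values. The `add` call is left commented out so the default configuration stays unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical overrides: both keys are assumptions based on the\n",
    "# naming pattern of `torch_compile_mode`, not confirmed pruna settings\n",
    "hqq_overrides = {\n",
    "    \"hqq_diffusers_weight_bits\": 4,  # assumed key; fewer bits shrink the weights but may cost quality\n",
    "    \"hqq_diffusers_group_size\": 64,  # assumed key; smaller groups store more scales and track the weights more closely\n",
    "}\n",
    "# smash_config.add(hqq_overrides)  # Uncomment to apply before smashing"
   ]
  },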
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Smashing the Model\n",
    "\n",
    "Now, smash the model. This can take up to 30 seconds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import smash\n",
    "\n",
    "# Smash the model\n",
    "pipe = smash(\n",
    "    model=pipe,\n",
    "    smash_config=smash_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Running the Model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, run the model to generate the image, note there will be a warmup the first time you run it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the prompt\n",
    "prompt = \"a smiling cat dancing on a table. Miyazaki style\"\n",
    "\n",
    "# Display the result\n",
    "pipe(prompt).images[0]"
   ]
  },
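  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the call above already triggered the warmup, later calls should be closer to steady-state latency. The cell below is a rough sketch for timing one more generation and reading the peak GPU memory; it assumes a CUDA device and reuses `pipe` and `prompt` from the previous cells. The helper `timed_generate` is hypothetical, written for this tutorial only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def timed_generate(pipeline, text):\n",
    "    # Hypothetical helper: run one generation and return (image, seconds elapsed)\n",
    "    torch.cuda.synchronize()  # finish pending GPU work before starting the timer\n",
    "    start = time.perf_counter()\n",
    "    image = pipeline(text).images[0]\n",
    "    torch.cuda.synchronize()  # wait for the GPU before stopping the timer\n",
    "    return image, time.perf_counter() - start\n",
    "\n",
    "\n",
    "torch.cuda.reset_peak_memory_stats()  # track the peak for this run only\n",
    "image, seconds = timed_generate(pipe, prompt)\n",
    "peak_gib = torch.cuda.max_memory_allocated() / 1024**3\n",
    "print(f\"Latency: {seconds:.2f} s, peak GPU memory: {peak_gib:.2f} GiB\")"
   ]
  },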
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wrap Up"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "Congratulations! You've optimized a diffusion model using HQQ quantization and ``torch.compile``. The quantized model uses less memory and runs faster while maintaining good quality. You can try different settings, such as weight bits and group size, to find the best balance between size and quality.\n",
    "\n",
    "Want more optimization techniques? Check out our other tutorials!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pruna",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
